import torch
import numpy as np
import torch.nn.functional as F
import torch.nn as nn


class BarlowTwinsLoss(nn.Module):
    """Barlow Twins redundancy-reduction loss between two batches of embeddings.

    Each embedding dimension is standardized over the batch, the
    cross-correlation matrix between the two views is formed, its diagonal is
    pulled toward 1 and its off-diagonal entries (shifted by ``tau``) are
    pushed toward 0.

    Args:
        lamb: weight of the off-diagonal (redundancy) term.
        tau: constant added to the cross-correlation matrix before the
            off-diagonal penalty; ``tau=0`` drives off-diagonal entries to 0.
    """

    def __init__(self, lamb=5e-3, tau=0.0):
        super(BarlowTwinsLoss, self).__init__()
        self.lamb = lamb
        self.tau = tau

    def forward(self, q, k):
        """Compute the loss for two 2-D batches ``q`` and ``k`` (rows = samples).

        Returns:
            ``(loss, {"part1": on_diag, "part2": off_diag})`` with
            ``loss = on_diag + lamb * off_diag``.
        """
        batch_size = q.size(0)
        # Standardize each view with its OWN statistics. torch.std is the
        # unbiased estimator, so perfectly correlated views give diagonal
        # entries of (batch_size - 1) / batch_size rather than exactly 1.
        q = (q - q.mean(0)) / q.std(0)
        k = (k - k.mean(0)) / k.std(0)

        # Empirical cross-correlation matrix averaged over the local batch
        # (no cross-process reduction happens here).
        c = (q.T @ k) / batch_size

        on_diag = (torch.diagonal(c) - 1).pow(2).sum()
        off_mask = ~torch.eye(c.size(0), dtype=torch.bool, device=c.device)
        off_diag = (c + self.tau)[off_mask].pow(2).sum()
        loss = on_diag + self.lamb * off_diag
        return loss, {"part1": on_diag, "part2": off_diag}


def off_diagonal(x):
    """Return the off-diagonal entries of square matrix ``x`` as a 1-D tensor.

    Entries come out in row-major order. Raises ``AssertionError`` when ``x``
    is not square (and ``ValueError`` from unpacking when it is not 2-D).
    """
    rows, cols = x.shape
    assert rows == cols
    # After dropping the final (diagonal) element, the flat buffer splits into
    # rows - 1 chunks of length rows + 1, each beginning with a diagonal entry
    # followed by the off-diagonal entries that precede the next one.
    trimmed = x.flatten()[:-1]
    chunks = trimmed.view(rows - 1, rows + 1)
    without_diagonal = chunks[:, 1:]
    return without_diagonal.flatten()


class MixedSpectralContrastiveLoss(nn.Module):
    """Weighted sum of the self-supervised and supervised spectral losses.

    ``total = SpectralContrastiveLoss(z1, z2)
              + lamb * SupervisedSpectralContrastiveLoss(z1, z2, labels)``;
    both component losses are built with the same ``mu`` and ``tau``.
    """

    def __init__(self, mu=1.0, tau=0.0, lamb=1.0):
        super().__init__()
        self.lamb = lamb
        self.ssl_criterion = SpectralContrastiveLoss(mu=mu, tau=tau)
        self.sup_criterion = SupervisedSpectralContrastiveLoss(mu=mu, tau=tau)

    def forward(self, z1, z2, labels):
        """Return ``(total, {"part1": self-supervised loss, "part2": supervised loss})``."""
        unsupervised = self.ssl_criterion(z1, z2)[0]
        supervised = self.sup_criterion(z1, z2, labels)[0]
        total = unsupervised + self.lamb * supervised
        return total, {"part1": unsupervised, "part2": supervised}


class MixedContrastiveLoss(nn.Module):
    """Weighted sum of the self-supervised and supervised contrastive losses.

    ``total = ContrastiveLoss(z1, z2)
              + lamb * SupervisedContrastiveLoss(z1, z2, labels)``;
    both component losses use the same ``temperature``.
    """

    def __init__(self, temperature=0.2, lamb=0.1):
        super().__init__()
        self.lamb = lamb
        self.ssl_criterion = ContrastiveLoss(temperature=temperature)
        self.sup_criterion = SupervisedContrastiveLoss(temperature=temperature)

    def forward(self, z1, z2, labels):
        """Return ``(total, {"part1": self-supervised loss, "part2": supervised loss})``."""
        unsupervised = self.ssl_criterion(z1, z2)[0]
        supervised = self.sup_criterion(z1, z2, labels)[0]
        total = unsupervised + self.lamb * supervised
        return total, {"part1": unsupervised, "part2": supervised}


class ContrastiveLoss(nn.Module):
    """NT-Xent style contrastive loss between two batches of paired embeddings.

    Rows ``i`` of ``z1`` and ``z2`` form the positive pair; every other row of
    the concatenated batch ``[z1; z2]`` acts only in the denominator.

    Args:
        temperature: divisor applied to the cosine similarities.
    """

    def __init__(self, temperature=0.1):
        super(ContrastiveLoss, self).__init__()
        self.temperature = temperature

    def forward(self, z1, z2):
        """Compute the loss for 2-D tensors ``z1`` and ``z2`` of identical shape.

        Returns:
            ``(loss, {"part1": ..., "part2": ...})`` where ``part1`` is the mean
            scaled similarity of the positive pairs, ``part2`` the mean
            log-partition (log-sum-exp over all non-self entries) of their
            anchors, and ``loss = -part1 + part2``.
        """
        assert z1.shape == z2.shape and len(z1.shape)==2
        device = z1.device
        N = z1.shape[0]

        features = torch.cat([z1, z2], dim=0)
        features = F.normalize(features, dim=-1)
        logits = torch.matmul(features, features.T).div(self.temperature)
        # Self-contrast entries (the diagonal) are excluded everywhere.
        self_mask = torch.eye(2 * N, dtype=torch.bool, device=device)
        logits_mask = (~self_mask).float()
        # Positives: (i, i + N) and (i + N, i).
        mask = torch.eye(N, device=device).repeat(2, 2) * logits_mask

        # Stable log-sum-exp over the non-self entries. A naive
        # exp().sum().log() overflows to inf once 1/temperature exceeds ~88 in
        # float32, and inf * 0 on the masked self term then yields NaN.
        log_denom = torch.logsumexp(
            logits.masked_fill(self_mask, float("-inf")), dim=-1, keepdim=True
        )

        loss_part1 = (logits * mask).sum() / mask.sum()
        loss_part2 = (log_denom * mask).sum() / mask.sum()
        loss = -loss_part1 + loss_part2
        return loss, {"part1": loss_part1, "part2": loss_part2}


class SupervisedContrastiveLoss(nn.Module):
    """Supervised contrastive loss over two views of a labeled batch.

    Every pair of distinct rows of ``[z1; z2]`` with equal labels is a positive
    pair (this always includes the cross-view pair ``(i, i + N)``); all other
    rows only contribute to the denominator. The loss is averaged over all
    positive pairs.

    Args:
        temperature: divisor applied to the cosine similarities.
    """

    def __init__(self, temperature=0.1):
        super(SupervisedContrastiveLoss, self).__init__()
        self.temperature = temperature

    def forward(self, z1, z2, labels):
        """Compute the loss.

        Args:
            z1, z2: 2-D tensors of identical shape; row ``i`` of each is a view
                of sample ``i``.
            labels: one class label per row of ``z1``.

        Returns:
            ``(loss, {"part1": ..., "part2": ...})`` where ``part1`` is the mean
            scaled similarity over positive pairs, ``part2`` the matching mean
            log-partition term, and ``loss = -part1 + part2``.
        """
        assert z1.shape == z2.shape and len(z1.shape)==2
        device = z1.device
        N = z1.shape[0]

        features = torch.cat([z1, z2], dim=0)
        features = F.normalize(features, dim=-1)
        logits = torch.matmul(features, features.T).div(self.temperature)
        # Positive pairs: equal labels, within or across views.
        labels = labels.view(-1, 1)
        mask = torch.eq(labels, labels.T).float().to(device).repeat(2, 2)
        # Self-contrast entries (the diagonal) are excluded everywhere.
        self_mask = torch.eye(2 * N, dtype=torch.bool, device=device)
        logits_mask = (~self_mask).float()
        mask = mask * logits_mask

        # Stable log-sum-exp over the non-self entries. A naive
        # exp().sum().log() overflows to inf once 1/temperature exceeds ~88 in
        # float32, and inf * 0 on the masked self term then yields NaN.
        log_denom = torch.logsumexp(
            logits.masked_fill(self_mask, float("-inf")), dim=-1, keepdim=True
        )

        loss_part1 = (logits * mask).sum() / mask.sum()
        loss_part2 = (log_denom * mask).sum() / mask.sum()
        loss = -loss_part1 + loss_part2
        return loss, {"part1": loss_part1, "part2": loss_part2}


class SpectralContrastiveLoss(nn.Module):
    """Spectral contrastive loss on L2-normalized embeddings.

    ``part1 = -2 * mean similarity over the pair mask`` (the mask covers both
    the cross-view pairs and the self-pairs on the diagonal) and
    ``part2 = mu * mean((similarity + tau) ** 2)`` over every pair.

    Args:
        mu: weight of the quadratic term.
        tau: offset added to every similarity inside the quadratic term.
    """

    def __init__(self, mu=1.0, tau=0.0):
        super().__init__()
        self.tau = tau
        self.mu = mu

    def forward(self, z1, z2):
        """Return ``(part1 + part2, {"part1": part1, "part2": part2})``."""
        assert z1.shape == z2.shape and len(z1.shape)==2
        batch = z1.shape[0]

        emb = F.normalize(torch.cat([z1, z2], dim=0), dim=1)
        sim = emb @ emb.T

        # Ones at (i, i), (i, i + batch), (i + batch, i), (i + batch, i + batch).
        pair_mask = torch.eye(batch, dtype=torch.float32, device=z1.device).repeat(2, 2)

        attract = -2 * (sim * pair_mask).sum() / pair_mask.sum()
        repel = self.mu * (sim + self.tau).pow(2).mean()
        return attract + repel, {"part1": attract, "part2": repel}



class SpectralContrastiveLossV1(nn.Module):
    """Second-moment (covariance) form of the spectral contrastive loss.

    Embeddings whose L2 norm is at least ``sqrt(mu)`` are rescaled to norm
    ``sqrt(mu)``; shorter ones are left untouched. With ``S`` the second-moment
    matrix pooled over both views, the loss is ``(R0 + R1 + 2 * tau * R2) / mu``:

    * ``R0`` = mean squared distance between paired embeddings,
    * ``R1`` = ``||S - I||_F ** 2``,
    * ``R2`` = squared norm of the mean embedding.

    ``R0 + R1`` equals ``-2 * mean(z1 . z2) + ||S||_F ** 2 + D`` exactly, i.e.
    the pairwise spectral loss up to an additive constant.
    """

    def __init__(self, mu=1.0, tau=0.0):
        super().__init__()
        self.tau = tau
        self.mu = mu

    def forward(self, z1, z2):
        """Return ``(loss, {"part1": R0/mu, "part2": R1/mu, "part3": R2/mu})``."""
        assert z1.shape == z2.shape and len(z1.shape)==2
        device = z1.device
        N, D = z1.shape

        # 1.0 for rows strictly inside the ball of radius sqrt(mu), else 0.0.
        mask1 = (z1.norm(dim=1) < np.sqrt(self.mu)).float().unsqueeze(1)
        mask2 = (z2.norm(dim=1) < np.sqrt(self.mu)).float().unsqueeze(1)

        z1 = mask1 * z1 + (1 - mask1) * F.normalize(z1, dim=1) * np.sqrt(self.mu)
        z2 = mask2 * z2 + (1 - mask2) * F.normalize(z2, dim=1) * np.sqrt(self.mu)

        R0 = ((z1 - z2) ** 2).sum(dim=-1).mean(dim=0)
        # The deviation from the identity must be SQUARED (Frobenius norm);
        # a linear `* 2` here does not penalize the moment matrix at all.
        second_moment = (z1.T @ z1 + z2.T @ z2) / (2 * N)
        R1 = ((second_moment - torch.eye(D, device=device)) ** 2).sum()
        R2 = ((0.5 * (z1 + z2)).mean(dim=0) ** 2).sum()

        return (R0 + R1 + 2 * self.tau * R2) / self.mu, {"part1": R0 / self.mu, "part2": R1 / self.mu, "part3": R2 / self.mu}


class SpectralContrastiveLossV2(nn.Module):
    """Spectral contrastive loss on norm-clipped (not normalized) embeddings.

    Rows with L2 norm below ``sqrt(mu)`` are kept as-is; the rest are rescaled
    to norm ``sqrt(mu)``. Then
    ``part1 = -2 * (1 - tau) * mean similarity over the pair mask`` (cross-view
    pairs plus self-pairs) and ``part2 = mean(similarity ** 2)``.

    Args:
        mu: squared clipping radius.
        tau: shrinks the attraction term by ``(1 - tau)``.
    """

    def __init__(self, mu=1.0, tau=0.0):
        super().__init__()
        self.tau = tau
        self.mu = mu

    def forward(self, z1, z2):
        """Return ``(part1 + part2, {"part1": part1, "part2": part2})``."""
        assert z1.shape == z2.shape and len(z1.shape)==2
        n_pairs = z1.shape[0]
        radius = np.sqrt(self.mu)

        emb = torch.cat([z1, z2], dim=0)
        # 1.0 for rows strictly inside the clipping ball, 0.0 otherwise.
        inside = (emb.norm(dim=1) < radius).float().unsqueeze(1)
        emb = inside * emb + (1 - inside) * F.normalize(emb, dim=1) * radius

        gram = emb @ emb.T
        # Ones at (i, i), (i, i + n_pairs), (i + n_pairs, i), (i + n_pairs, i + n_pairs).
        pair_mask = torch.eye(n_pairs, dtype=torch.float32, device=z1.device).repeat(2, 2)

        attract = -2 * (1 - self.tau) * (gram * pair_mask).sum() / pair_mask.sum()
        repel = gram.pow(2).mean()
        return attract + repel, {"part1": attract, "part2": repel}


class SupervisedSpectralContrastiveLoss(nn.Module):
    """Spectral contrastive loss in which same-label samples are positives.

    On L2-normalized embeddings ``[z1; z2]``:
    ``part1 = -2 * mean similarity over all pairs with equal labels`` (self-pairs
    included) and ``part2 = mu * mean((similarity + tau) ** 2)`` over every pair.

    Args:
        mu: weight of the quadratic term.
        tau: offset added to every similarity inside the quadratic term.
    """

    def __init__(self, mu=1.0, tau=0.0):
        super().__init__()
        self.tau = tau
        self.mu = mu

    def forward(self, z1, z2, labels):
        """Return ``(part1 + part2, {"part1": part1, "part2": part2})``.

        ``labels`` holds one class label per row of ``z1``; it may live on a
        different device than the embeddings.
        """
        assert z1.shape == z2.shape and len(z1.shape) == 2
        device = z1.device

        z = torch.cat([z1, z2], dim=0)
        z = F.normalize(z, dim=1)

        logits = torch.matmul(z, z.T)
        # Positive pairs share a label. Build the mask on the embeddings'
        # device so CPU labels combine with GPU logits.
        labels = labels.view(-1, 1)
        mask = torch.eq(labels, labels.T).float().to(device).repeat(2, 2)

        loss_part1 = -2 * (logits * mask).sum() / mask.sum()
        loss_part2 = self.mu * ((logits + self.tau) ** 2).mean()

        return loss_part1 + loss_part2, {"part1": loss_part1, "part2": loss_part2}



class SCELoss(nn.Module):
    """Symmetric cross-entropy: ``a * CE + b * reverse CE``.

    The reverse term swaps the roles of prediction and target. The one-hot
    target is clamped to ``[1e-4, 1]`` so its logarithm stays finite.

    Args:
        num_classes: number of classes for the one-hot encoding.
        a: weight of the standard cross-entropy term.
        b: weight of the reverse cross-entropy term.
    """

    def __init__(self, num_classes=10, a=1, b=1):
        super(SCELoss, self).__init__()
        self.num_classes = num_classes
        self.a = a
        self.b = b
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, pred, labels, eps=1e-7):
        """Return the scalar loss for logits ``pred`` and integer ``labels``.

        ``eps`` is the lower clamp applied to the softmax probabilities.
        """
        forward_ce = self.cross_entropy(pred, labels)

        probs = F.softmax(pred, dim=1).clamp(min=eps, max=1.0)
        target = F.one_hot(labels, self.num_classes).float().to(pred.device)
        target = target.clamp(min=1e-4, max=1.0)
        reverse_ce = -(probs * target.log()).sum(dim=1)

        return self.a * forward_ce + self.b * reverse_ce.mean()
    
class GCELoss(nn.Module):
    """Generalized cross-entropy: mean of ``(1 - p_y ** q) / q``.

    ``p_y`` is the softmax probability of the true class, clamped to
    ``[eps, 1]``.

    Args:
        num_classes: number of classes for the one-hot encoding.
        q: exponent of the generalized loss.
    """

    def __init__(self, num_classes=10, q=0.7):
        super(GCELoss, self).__init__()
        self.q = q
        self.num_classes = num_classes

    def forward(self, pred, labels, eps=1e-7):
        """Return the scalar loss for logits ``pred`` and integer ``labels``."""
        probs = torch.clamp(F.softmax(pred, dim=1), min=eps, max=1.0)
        onehot = F.one_hot(labels, self.num_classes).float().to(pred.device)
        p_true = (onehot * probs).sum(dim=1)
        per_sample = (1. - p_true.pow(self.q)) / self.q
        return per_sample.mean()

class FocalLoss(torch.nn.Module):
    '''
        Multi-class focal loss ``-(1 - p_t) ** gamma * alpha_t * log(p_t)``.

        Adapted from
        https://github.com/clcarwin/focal_loss_pytorch/blob/master/focalloss.py

        Args:
            gamma: focusing exponent.
            alpha: optional per-class weights. A float/int ``a`` means
                ``[a, 1 - a]`` (two classes); a list/tuple gives one weight per
                class; ``None`` disables weighting.
            size_average: return the mean over elements if True, else the sum.
    '''

    def __init__(self, gamma=0.5, alpha=None, size_average=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        if isinstance(alpha, (float, int)):
            self.alpha = torch.Tensor([alpha, 1-alpha])
        if isinstance(alpha, (list, tuple)):
            self.alpha = torch.Tensor(alpha)
        self.size_average = size_average

    def forward(self, input, target):
        """Return the focal loss for logits ``input`` and class-index ``target``.

        ``input`` is ``(N, C)`` or ``(N, C, d1, d2, ...)``; extra dimensions are
        flattened so every spatial position is one classification.
        """
        if input.dim() > 2:
            # reshape (not view) so non-contiguous inputs are accepted too.
            input = input.reshape(input.size(0), input.size(1), -1)  # N,C,H,W => N,C,H*W
            input = input.transpose(1, 2)                            # N,C,H*W => N,H*W,C
            input = input.reshape(-1, input.size(2))                 # N,H*W,C => N*H*W,C
        target = target.reshape(-1, 1)

        logpt = F.log_softmax(input, dim=1).gather(1, target).view(-1)
        # The modulating factor uses a detached p_t, as in the reference
        # implementation. This also keeps gradients finite for gamma < 1,
        # where d/dp (1 - p) ** gamma diverges as p -> 1.
        pt = logpt.detach().exp()

        if self.alpha is not None:
            # Convert locally instead of caching on self inside forward.
            alpha = self.alpha.to(device=input.device, dtype=input.dtype)
            logpt = logpt * alpha.gather(0, target.view(-1))

        loss = -1 * (1 - pt) ** self.gamma * logpt
        return loss.mean() if self.size_average else loss.sum()

class NLNL(torch.nn.Module):
    """Negative-learning (NL) loss with inverse-frequency class weights.

    For every sample, ``ln_neg`` complementary labels are drawn uniformly from
    the classes other than the given label. The loss penalizes
    ``-log(1 - p_complementary)``, weighted by the class weight of the given
    label.

    Args:
        train_loader: object whose ``dataset.targets`` (when present) holds the
            training labels used to build weights ``max_count / count``;
            without ``targets`` every weight is 1.
        num_classes: number of classes.
        ln_neg: number of complementary labels sampled per example.
    """

    def __init__(self, train_loader, num_classes=10, ln_neg=1):
        super(NLNL, self).__init__()
        self.num_classes = num_classes
        self.ln_neg = ln_neg
        weight = torch.ones(num_classes)
        if hasattr(train_loader.dataset, 'targets'):
            # Convert the targets once instead of once per class.
            targets = torch.from_numpy(np.array(train_loader.dataset.targets))
            for i in range(num_classes):
                weight[i] = (targets == i).sum()
            # NOTE(review): a class absent from `targets` gets weight inf here.
            weight = 1 / (weight / weight.max())
        # Keep GPU placement when available (as before) without crashing on
        # CPU-only hosts; forward() moves the weights to the logits' device.
        if torch.cuda.is_available():
            weight = weight.cuda()
        self.weight = weight
        self.criterion = torch.nn.CrossEntropyLoss(weight=self.weight)
        self.criterion_nll = torch.nn.NLLLoss()

    def forward(self, pred, labels):
        """Return the scalar NL loss.

        Args:
            pred: logits; softmax is taken over dim 1 (presumably
                ``(batch, num_classes)``).
            labels: integer class labels, one per row of ``pred``.
        """
        # A random offset in [1, num_classes), taken modulo num_classes, yields
        # a complementary label that can never equal the given label.
        offsets = torch.randint(1, self.num_classes, (len(labels), self.ln_neg), device=labels.device)
        labels_neg = (labels.unsqueeze(-1).repeat(1, self.ln_neg) + offsets) % self.num_classes

        assert labels_neg.max() <= self.num_classes-1
        assert labels_neg.min() >= 0
        assert (labels_neg != labels.unsqueeze(-1).repeat(1, self.ln_neg)).sum() == len(labels)*self.ln_neg

        weight = self.weight.to(pred.device)
        s_neg = torch.log(torch.clamp(1. - F.softmax(pred, 1), min=1e-5, max=1.))
        s_neg = s_neg * weight[labels].unsqueeze(-1).expand(s_neg.size())

        # The positive-learning term is disabled: every positive label counts
        # as ignored, so its weight (number of kept labels) is 0. Calling
        # CrossEntropyLoss on an all-ignored batch returns NaN (0 / 0) in
        # current PyTorch, and NaN * 0 stays NaN, so the term is skipped.
        num_pos = 0
        loss_pos = pred.new_zeros(())
        num_neg = float((labels_neg >= 0).sum())
        loss_neg = self.criterion_nll(s_neg.repeat(self.ln_neg, 1), labels_neg.t().contiguous().view(-1)) * num_neg
        loss = (loss_pos + loss_neg) / (float(num_pos) + float((labels_neg[:, 0] >= 0).sum()))
        return loss